
# coding: utf-8

# # GAN 
# 

# In[1]:


import sys
sys.executable
import numpy as np


# ## Imports

# In[2]:


get_ipython().run_line_magic('pylab', 'inline')
import tensorflow as tf
import itertools
import seaborn as sns
slim = tf.contrib.slim
ds = tf.contrib.distributions
import tensorflow.contrib.learn as tf_learn
import numpy as np
import matplotlib.pyplot as plt


# ## Parameters

# In[3]:


# Experiment configuration read by the cells below.
params = dict(
    batch_size=500,             # rows in the noise batch and in each real-data sample
    latent_dim=2,               # handed to generator() as the width of its output layer
    input_dim=16,               # width of the noise vector z
    n_hidden_disc=128,          # hidden-layer width of the discriminator
    n_hidden_gen=128,           # hidden-layer width of the generator
    dataset='ring',             # choose one from ['grid', 'ring', 'spiral']
    number_saved_snapshots=10,
    switch_every_steps=999,
    across_epochs=5,
)

# Derived settings; 1000 is the number of batches assumed per epoch.
num_reservoir = params['across_epochs']
switch_every = (1000 * num_reservoir) / params['number_saved_snapshots']


# ## Generate 2D Data

# In[4]:


def archimedes_spiral(theta, theta_offset=0., *args, **kwargs):
    """Return Cartesian coordinates of an Archimedes spiral with radius == theta.

    Each axis is rescaled independently by its largest absolute value, so the
    returned coordinates lie in [-1, 1] on both axes.

    Args:
        theta: array-like, polar angles in radians; the same values are used
            as the radius of each point.
        theta_offset: float, rotation added to every angle, in radians.
        *args, **kwargs: accepted for call compatibility and ignored.

    Returns:
        Tuple ``(x, y)`` of arrays with the same shape as ``theta``.
    """
    theta = np.asarray(theta)
    angle = theta + theta_offset
    x = theta * np.cos(angle)
    y = theta * np.sin(angle)

    # An empty input has no maximum; treat its norm as zero so it is
    # returned unchanged instead of raising inside np.max.
    x_norm = np.max(np.abs(x)) if x.size else 0.0
    y_norm = np.max(np.abs(y)) if y.size else 0.0

    # Only rescale an axis that has a non-zero extent: dividing an all-zero
    # axis by its (zero) norm would turn every coordinate into NaN.
    if x_norm > 0:
        x = x / x_norm
    if y_norm > 0:
        y = y / y_norm
    return x, y

def _mixture_means(num_components, dataset, radius):
    """Return the float32 array of 2-D component means for `dataset`.

    Raises:
        ValueError: if `dataset` is not 'grid', 'ring' or 'spiral'.
    """
    if dataset == 'grid':
        # Fixed 5 x 5 lattice on {-4, -2, 0, 2, 4}^2; num_components does not
        # change the number of means generated here.
        axis_values = range(-4, 5, 2)
        return np.array(list(itertools.product(axis_values, axis_values)),
                        dtype=np.float32)
    if dataset == 'ring':
        # num_components equally spaced angles; the endpoint 2*pi is dropped
        # because it coincides with angle 0.
        thetas = np.linspace(0, 2 * np.pi, num_components + 1)[:-1]
        xs, ys = radius * np.sin(thetas), radius * np.cos(thetas)
        return np.column_stack((xs, ys)).astype(np.float32)
    if dataset == 'spiral':
        n_loops = 2
        angles = np.linspace(0, 2 * n_loops * np.pi, num_components)
        spir_x, spir_y = archimedes_spiral(angles, 0 * np.pi)
        return np.float32(np.vstack((spir_x, spir_y)).T)
    raise ValueError(
        "unknown dataset %r; choose one from ['grid', 'ring', 'spiral']" % (dataset,))


def create_distribution(num_components=25, num_features=2, dataset='grid', radius=1.0, **kwargs):
    """Build a mixture of diagonal Gaussians laid out as a grid, ring or spiral.

    Args:
        num_components: number of categories of the mixing distribution and
            number of scale vectors created.
        num_features: currently unused; the means and scales built here are
            always 2-D.
        dataset: one of 'grid', 'ring', 'spiral'.
        radius: radius of the circle the means sit on for 'ring'.
        **kwargs: accepted for call compatibility and ignored.

    Returns:
        Tuple ``(data, mus)``: the `ds.Mixture` distribution and the float32
        array of component means it was built from.

    Raises:
        ValueError: if `dataset` is not one of the supported names.
    """
    # Validate `dataset` and compute the means first (pure NumPy), so a bad
    # name fails before any TensorFlow ops are added to the graph.
    mus = _mixture_means(num_components, dataset, radius)

    # All-zero parameters for the categorical, i.e. equal weight for every
    # component if they are interpreted as logits — NOTE(review): confirm
    # against the installed tf.contrib.distributions version.
    cat = ds.Categorical(tf.zeros(num_components, dtype=tf.float32))
    s = 0.01
    sigmas = [np.array([s, s]).astype(np.float32) for _ in range(num_components)]

    # zip() pairs means with scales position by position, so surplus entries
    # of the longer sequence are silently dropped.
    components = [ds.MultivariateNormalDiag(mu, sigma)
                  for mu, sigma in zip(mus, sigmas)]
    data = ds.Mixture(cat, components)
    return data, mus


# ## Network definitions

# In[5]:


# Generator network of the GAN.
def generator(z, latent_dim, n_hidden, scope="generator"):
    """Map the noise batch `z` to generated samples.

    Args:
        z: input tensor fed to the first fully connected layer.
        latent_dim: width of the final, linear layer (the sample dimension).
        n_hidden: width of each of the two ReLU hidden layers.
        scope: variable scope under which the layers are created.

    Returns:
        Output tensor of the final layer (created under the name "x_g").
    """
    with tf.variable_scope(scope):
        hidden = z
        # Two identical ReLU hidden layers stacked on top of the noise input.
        for _ in range(2):
            hidden = slim.fully_connected(hidden, n_hidden, activation_fn=tf.nn.relu)
        samples = slim.fully_connected(hidden, latent_dim, activation_fn=None, scope="x_g")
    return samples

# Discriminator network of the GAN.
def discriminator(x, n_hidden=128, reuse=False, activation_fn=None, scope="discriminator"):
    """Score each row of `x` with a one-hidden-layer network.

    Args:
        x: input tensor fed to the hidden layer.
        n_hidden: width of the ReLU hidden layer.
        reuse: when true, switch the scope to reuse the variables of an
            earlier call with the same `scope` instead of creating new ones.
        activation_fn: activation of the single-unit output layer; the default
            None leaves the output linear.
        scope: variable scope under which the layers are created.

    Returns:
        The output layer with its size-1 second axis squeezed away.
    """
    with tf.variable_scope(scope) as var_scope:
        if reuse:
            var_scope.reuse_variables()
        hidden = slim.fully_connected(x, n_hidden, activation_fn=tf.nn.relu)
        score = slim.fully_connected(hidden, 1, activation_fn=activation_fn)
    squeezed = tf.squeeze(score, squeeze_dims=[1])
    return squeezed


# ## Construct model and training ops

# In[6]:


# Start from an empty default graph before building the model.
tf.reset_default_graph()
#tf.set_random_seed(1234)
#np.random.seed(seed=1234)

# create real data
# Number of mixture components used for each supported dataset. A lookup with
# an explicit error replaces three independent `if`s, which left `components`
# unbound (or stale from an earlier notebook run) for a mistyped dataset name.
components_by_dataset = {'grid': 25, 'ring': 8, 'spiral': 20}
if params['dataset'] not in components_by_dataset:
    raise ValueError("unknown dataset %r; choose one from %s"
                     % (params['dataset'], sorted(components_by_dataset)))
components = components_by_dataset[params['dataset']]
real_data, real_mus = create_distribution(components, dataset=params['dataset'])

# z is the noise that goes to the generator; z \sim p(z)
z = tf.random_normal([params['batch_size'], params['input_dim']])
x_g = generator(z, params['latent_dim'], params['n_hidden_gen'], scope="generator1")

# A fresh batch of real samples is drawn every time x_r is evaluated.
x_r = real_data.sample(params['batch_size'])

# D(real)
d_real = discriminator(x_r, n_hidden=params['n_hidden_disc'], scope="discriminator1")
# D(fake) -- reuse=True makes this call share the "discriminator1" variables
# created by the call above.
d_fake = discriminator(x_g, n_hidden=params['n_hidden_disc'], scope="discriminator1", reuse=True)

# Bookkeeping lists; only `generators` is populated in this cell.
generators = []
generators.append(x_g)
for_d_loss = []
discriminators_real = []
discriminators_fake = []

# Discriminator loss: sigmoid cross-entropy with real samples labelled 1 and
# generated samples labelled 0.
disc_loss = tf.reduce_mean(
    tf.nn.sigmoid_cross_entropy_with_logits(logits=d_real, labels=tf.ones_like(d_real)) +
    tf.nn.sigmoid_cross_entropy_with_logits(logits=d_fake, labels=tf.zeros_like(d_fake)))
# saturating update rule
# gen_loss = -tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=d_fake, labels=tf.zeros_like(d_fake)))
# non-saturating update rule: the generator is rewarded when its samples are
# scored as label 1.
gen_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=d_fake, labels=tf.ones_like(d_fake)))

# Placeholders for logits that are computed in one sess.run call and fed back
# in to evaluate a loss in another: plc_float carries logits on generated
# samples, plc_float_r logits on real samples. A single plc_float is shared by
# every *_calc expression below; creating a second placeholder under the same
# name would leave disc_loss_calc / gen_loss_calc depending on a placeholder
# that can no longer be reached by name to feed it.
plc_float = tf.placeholder(tf.float32)
plc_float_r = tf.placeholder(tf.float32)

# Discriminator / generator losses with the fake logits supplied externally.
disc_loss_calc = tf.reduce_mean(
    tf.nn.sigmoid_cross_entropy_with_logits(logits=d_real, labels=tf.ones_like(d_real)) +
    tf.nn.sigmoid_cross_entropy_with_logits(logits=plc_float, labels=tf.zeros_like(plc_float)))
gen_loss_calc = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=plc_float, labels=tf.ones_like(plc_float)))

# The scope argument of get_collection is matched against the start of each
# variable name, so "generator" / "discriminator" select the variables created
# under "generator1" / "discriminator1".
qvars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, "generator")
dvars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, "discriminator")
opt_g = tf.train.AdamOptimizer(1e-3, beta1=.5)
opt_d = tf.train.AdamOptimizer(1e-4, beta1=.5)

train_gen_op = opt_g.minimize(gen_loss, var_list=qvars)
# The discriminator is trained with its own optimizer. Building this op from
# opt_g left opt_d unused and trained the discriminator at the generator's
# learning rate (1e-3 instead of 1e-4).
train_disc_op = opt_d.minimize(disc_loss, var_list=dvars)

# Negated cross-entropies, i.e. the value of the GAN objective
# mean(log D(real)) + mean(log(1 - D(fake))) for externally supplied logits.
disc_loss_calc2 = -tf.reduce_mean(
    tf.nn.sigmoid_cross_entropy_with_logits(logits=plc_float_r, labels=tf.ones_like(plc_float_r)) +
    tf.nn.sigmoid_cross_entropy_with_logits(logits=plc_float, labels=tf.zeros_like(plc_float)))
# Same, restricted to the generated-sample term.
disc_loss_calc1 = -tf.reduce_mean(
    tf.nn.sigmoid_cross_entropy_with_logits(logits=plc_float, labels=tf.zeros_like(plc_float)))
# "Worst-case" discriminator: a second discriminator network, created under
# worst_calc/discriminatWorst, that is trained on its own against the current
# generator output x_g.
with tf.variable_scope('worst_calc', reuse=tf.AUTO_REUSE):
    new_opt = tf.train.AdamOptimizer(1e-3, beta1=.5)
    # Both calls use the same scope name; with AUTO_REUSE on the enclosing
    # scope the second call shares the variables created by the first.
    df = discriminator(x_g, n_hidden=params['n_hidden_disc'], scope="discriminatWorst")
    dr = discriminator(x_r, n_hidden=params['n_hidden_disc'], scope="discriminatWorst")
    # Same cross-entropy form as the main discriminator loss
    # (real labelled 1, generated labelled 0).
    disc_loss_worst = tf.reduce_mean(
    tf.nn.sigmoid_cross_entropy_with_logits(logits=dr, labels=tf.ones_like(dr)) +
    tf.nn.sigmoid_cross_entropy_with_logits(logits=df, labels=tf.zeros_like(df)))
    # Snapshot of the global variables taken BEFORE minimize() runs, filtered
    # by substring on the variable name; only these variables are updated by
    # find_worst_d.
    t_vars = tf.global_variables()
    d_vars_worst = [var for var in t_vars if 'discriminatWorst' in var.name]

    find_worst_d = new_opt.minimize(disc_loss_worst, var_list=d_vars_worst)
# Refreshed snapshot taken after minimize(); d_init keeps every variable whose
# name contains 'worst_calc'.
# NOTE(review): presumably this now also includes variables created by
# new_opt itself -- confirm before using d_init to reset state.
t_vars = tf.global_variables()
d_init = [var for var in t_vars if 'worst_calc' in var.name]



# "Worst-case" generator: a second generator network, created under
# worst_calc_gen/generator_worst, that is trained on its own against the main
# discriminator.
with tf.variable_scope('worst_calc_gen', reuse=tf.AUTO_REUSE):
    new_opt_gen = tf.train.AdamOptimizer(1e-3, beta1=.5)
    x_w = generator(z, params['latent_dim'], params['n_hidden_gen'], scope='generator_worst')

# These two calls are made outside the 'worst_calc_gen' scope so that
# scope="discriminator1" refers to the top-level scope of the main
# discriminator, whose variables are shared through reuse=True.
dr_w = discriminator(x_r, n_hidden=params['n_hidden_disc'], scope="discriminator1", reuse=True)
df_w = discriminator(x_w, n_hidden=params['n_hidden_disc'], scope="discriminator1", reuse=True)
with tf.variable_scope('worst_calc_gen', reuse=tf.AUTO_REUSE):
    # Non-saturating generator loss, evaluated on the worst-case generator.
    gen_loss_worst = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=df_w, labels=tf.ones_like(df_w)))
    # Snapshot taken BEFORE minimize() runs. The variable-copy cell further
    # down reads this same t_vars, so the order of these statements matters.
    t_vars = tf.global_variables()
    g_vars_worst = [var for var in t_vars if 'generator_worst' in var.name]
    # Call form of print: same output as the Python 2 statement for a single
    # argument, and consistent with the print (...) calls in the training cell.
    print(g_vars_worst)

    # Only the worst-case generator's variables are updated by this op; the
    # main discriminator is left untouched.
    find_worst_g = new_opt_gen.minimize(gen_loss_worst, var_list=g_vars_worst)
    



# ## Construct function for swapping variables

# In[7]:


# Build one grouped op, current_to_tmp, that overwrites the "worst-case"
# networks with the current values of the main networks
# (discriminator1 -> discriminatWorst, generator1 -> generator_worst).
curr_to_tmp=[]

# t_vars is NOT refreshed here: it is whatever list the model-construction
# cell last bound to that name, so only variables that existed at that point
# are considered. Selection is by substring match on the variable name.
# NOTE(review): source and destination variables are paired purely by their
# position in these lists. That only works if both lists come out in the same
# order and the first len(*_tmp) entries of each *_0 list are the matching
# variables; the substring filters may also pick up optimizer-created
# variables whose names contain these scope names -- confirm by printing the
# paired names (see the commented-out prints below).
d_vars_tmp = [var for var in t_vars if  'discriminatWorst' in var.name]
d_vars_0 = [var for var in t_vars if 'discriminator1/' in var.name]
g_vars_tmp = [var for var in t_vars if  'generator_worst' in var.name]
g_vars_0 = [var for var in t_vars if 'generator1/' in var.name]
# Discriminator: one assign op per destination variable.
for j in range(0, len(d_vars_tmp)):
#     print d_vars_tmp[j].name
#     print d_vars_0[j].name
    curr_to_tmp.append(d_vars_tmp[j].assign(d_vars_0[j]))
# Generator: the loop length is that of the destination list, so any extra
# trailing entries in g_vars_0 are ignored.
for j in range(0, len(g_vars_tmp)):
    curr_to_tmp.append(g_vars_tmp[j].assign(g_vars_0[j]))

# Running current_to_tmp executes all of the assignments above.
current_to_tmp = tf.group(*curr_to_tmp)


# In[8]:


# Open an interactive session and give every variable in the graph its
# initial value before any training op is run.
sess = tf.InteractiveSession()
init_op = tf.global_variables_initializer()
sess.run(init_op)


# ## Train model

# In[9]:


all_scores = []        # duality-gap estimate, one entry per epoch
duality_gap_mean = []
from tqdm import tqdm
import collections
fs = []
np_samples = []        # generator samples, one stacked array per epoch
d_losses=[]            # discriminator loss, logged every 100 batches
g_losses=[]            # generator loss, logged every 100 batches

total_batch = 1000
#  Training cycle
for epoch in tqdm(range(25)):
    # Scatter plot of generated samples against real samples (5 batches of
    # each) at the start of the epoch.
    xx = np.vstack([sess.run(x_g) for _ in range(5)])
    yy = np.vstack([sess.run(x_r) for _ in range(5)])
    fig_ = figure(figsize=(5,5), facecolor='w')

    scatter(xx[:, 0], xx[:, 1],
            edgecolor='none', alpha=0.5)
    scatter(yy[:, 0], yy[:, 1], c='g', edgecolor='none')
    show()

    # Loop over all batches: generator and discriminator are updated in the
    # same sess.run call.
    for i in range(total_batch):
        loss_cum, _, _ = sess.run([[gen_loss, disc_loss], train_gen_op, train_disc_op])
        # The fetch list is [gen_loss, disc_loss]: unpack in that order so the
        # values are printed and stored under the right label.
        g_loss_val, d_loss_val = loss_cum
        if i % 100 == 0:
            print("loss d")
            print(d_loss_val)
            d_losses.append(d_loss_val)
            print("loss g")
            print(g_loss_val)
            g_losses.append(g_loss_val)

    # ---- duality-gap estimate for the current (G, D) pair ----
    # Reset both worst-case networks to the current main networks.
    sess.run(current_to_tmp)

    # Train the discriminator copy against the current generator, then
    # evaluate the objective of (current G, trained D copy).
    for j in range(1000):
        sess.run(find_worst_d)
    df_final = sess.run(df)
    dr_final = sess.run(dr)
    worst_minmax = sess.run(disc_loss_calc2, feed_dict={plc_float: df_final, plc_float_r: dr_final})

    # Train the generator copy against the current discriminator, then
    # evaluate the objective of (trained G copy, current D).
    for j in range(1000):
        sess.run(find_worst_g)
    df_final = sess.run(df_w)
    dr_final = sess.run(dr_w)
    worst_maxmin = sess.run(disc_loss_calc2, feed_dict={plc_float: df_final, plc_float_r: dr_final})

    dualitygap_score = worst_minmax - worst_maxmin
    all_scores.append(dualitygap_score)
    # Keep 10 batches of generator samples for the per-epoch density plots.
    np_samples.append(np.vstack([sess.run(x_g) for _ in range(10)]))
                


# In[13]:


import seaborn as sns
# One KDE panel per kept snapshot of generator samples, laid out in a row with
# shared axes.
xmax = 5
viz_every=1000          # training batches per epoch (one snapshot is stored per epoch)
snapshot_stride = 2     # plot every second stored snapshot
np_samples_ = np_samples[::snapshot_stride]
cols = len(np_samples_)
bg_color  = sns.color_palette('Blues', n_colors=256)[0]
figure(figsize=(10*cols, 10))
for i, samps in enumerate(np_samples_):
    if i == 0:
        ax = subplot(1,cols,1)
    else:
        subplot(1,cols,i+1, sharex=ax, sharey=ax)
    # The density is 2-D, so clip takes one [low, high] pair per axis.
    ax2 = sns.kdeplot(samps[:, 0], samps[:, 1], shade=True, cmap='Blues', n_levels=20, clip=[[-xmax,xmax]]*2)
    # Prefer set_facecolor where it exists; fall back to the older
    # set_axis_bgcolor on matplotlib versions that predate it.
    set_background = getattr(ax2, 'set_facecolor', None) or ax2.set_axis_bgcolor
    set_background(bg_color)
    xticks([]); yticks([])
    # Panel i shows stored snapshot i*snapshot_stride, which was taken at the
    # end of that epoch, i.e. after (i*snapshot_stride + 1) epochs of batches.
    title('step %d'%((i*snapshot_stride + 1)*viz_every))
gcf().tight_layout()


# In[10]:



# Duality-gap estimate: one score is appended per epoch.
plt.plot(range(0, len(all_scores)), all_scores)
plt.ylabel('duality gap')
plt.xlabel('epochs')
plt.show()



# In[11]:


# The loss lists get one entry every 100 training batches (several per
# epoch), so their x axis counts log entries rather than epochs.
plt.plot(range(0, len(d_losses)), d_losses)
plt.ylabel('d_losses')
plt.xlabel('logged step (every 100 batches)')
plt.show()


# In[12]:


plt.plot(range(0, len(g_losses)), g_losses)
plt.ylabel('g_losses')
plt.xlabel('logged step (every 100 batches)')
plt.show()


# In[ ]:


# Sample 5 batches from the generator and 5 batches of real data.
xx = np.vstack([sess.run(x_g) for _ in range(5)])
yy = np.vstack([sess.run(x_r) for _ in range(5)])


# KDE plots: generated samples on top, real samples below.
sns.set(font_scale=2)
f, (ax1, ax2) = plt.subplots(2, figsize=(10, 15))
cmap = sns.cubehelix_palette(as_cmap=True, dark=0, light=1, reverse=True)
sns.kdeplot(xx[:, 0], xx[:, 1], cmap=cmap, ax=ax1, n_levels=100, shade=True, clip=[[-6, 6]]*2)
sns.kdeplot(yy[:, 0], yy[:, 1], cmap=cmap, ax=ax2, n_levels=100, shade=True, clip=[[-6, 6]]*2)


# Evaluation: assign every generated point to its nearest mixture mean and
# count the points that land close to it.
# list() first so any iterable of (x, y) pairs becomes a 2-D array.
MEANS = np.asarray(list(real_mus))

# l2_store[p, m] = squared Euclidean distance from generated point p to mean m.
l2_store = np.sum((xx[:, np.newaxis, :] - MEANS[np.newaxis, :, :]) ** 2, axis=2)

nearest = np.argmin(l2_store, axis=1)
mode = nearest.tolist()
dis_ = l2_store[np.arange(len(nearest)), nearest]
# NOTE(review): the message below says "3 std." while the threshold is 0.15;
# confirm this matches the component scale used when building the data.
mode_counter = nearest[np.sqrt(dis_) <= 0.15].tolist()

mode_hits = collections.Counter(mode_counter)
# Single-argument print calls give the same text under Python 2 and 3; the
# double space reproduces the separator the comma-form print statement added.
print('Number of Modes Captured:  %d' % len(mode_hits))
# The per-mode counts sum to the number of retained points.
print('Number of Points Falling Within 3 std. of the Nearest Mode  %d' % len(mode_counter))

